{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a custom network\n",
    "In the previous section, we learned how to create an A2C agent using Stable Baselines. Instead of\n",
    "using the default network, can we customize the network architecture? Yes! With Stable\n",
    "Baselines, we can also use our own custom architecture. Let's see how to do that."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import gym\n",
    "from stable_baselines.common.policies import FeedForwardPolicy\n",
    "from stable_baselines.common.vec_env import DummyVecEnv\n",
    "from stable_baselines.common.evaluation import evaluate_policy\n",
    "from stable_baselines import A2C\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the lunar lander environment using gym:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('LunarLander-v2')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's use the dummy vectorized environment. Recall that in the dummy vectorized\n",
    "environment, we run each environment in the same process:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = DummyVecEnv([lambda: env])"
   ]
  },
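  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal sketch of what a dummy vectorized environment\n",
    "does. The classes `ToyEnv` and `ToyDummyVecEnv` are hypothetical stand-ins written for\n",
    "illustration; they are not part of gym or Stable Baselines, and the real `DummyVecEnv`\n",
    "handles more details. The sketch steps each wrapped environment one after another in the\n",
    "same process, collects the results, and resets an environment as soon as its episode ends:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyEnv:\n",
    "    '''Hypothetical environment: the episode ends after max_steps steps.'''\n",
    "    def __init__(self, max_steps=3):\n",
    "        self.max_steps = max_steps\n",
    "        self.t = 0\n",
    "\n",
    "    def reset(self):\n",
    "        self.t = 0\n",
    "        return [0.0]\n",
    "\n",
    "    def step(self, action):\n",
    "        self.t += 1\n",
    "        done = self.t >= self.max_steps\n",
    "        return [float(self.t)], 1.0, done, {}\n",
    "\n",
    "\n",
    "class ToyDummyVecEnv:\n",
    "    '''Hypothetical wrapper: runs every environment sequentially in one process.'''\n",
    "    def __init__(self, env_fns):\n",
    "        self.envs = [fn() for fn in env_fns]\n",
    "\n",
    "    def reset(self):\n",
    "        return [env.reset() for env in self.envs]\n",
    "\n",
    "    def step(self, actions):\n",
    "        obs, rewards, dones, infos = [], [], [], []\n",
    "        for toy_env, action in zip(self.envs, actions):\n",
    "            ob, reward, done, info = toy_env.step(action)\n",
    "            if done:\n",
    "                # Start a new episode right away, so the caller never has to reset\n",
    "                ob = toy_env.reset()\n",
    "            obs.append(ob)\n",
    "            rewards.append(reward)\n",
    "            dones.append(done)\n",
    "            infos.append(info)\n",
    "        return obs, rewards, dones, infos\n",
    "\n",
    "\n",
    "toy_vec_env = ToyDummyVecEnv([lambda: ToyEnv(max_steps=2), lambda: ToyEnv(max_steps=3)])\n",
    "toy_obs = toy_vec_env.reset()\n",
    "for _ in range(3):\n",
    "    toy_obs, toy_rewards, toy_dones, toy_infos = toy_vec_env.step([0, 0])\n",
    "    print(toy_obs, toy_dones)"
   ]
  },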
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can define our custom policy (custom network) as shown below. As we can observe\n",
    "in the code below, we pass `net_arch=[dict(pi=[128, 128, 128], vf=[128, 128, 128])]`,\n",
    "which specifies our network architecture: `pi` gives the hidden layer sizes of the policy\n",
    "network and `vf` gives the hidden layer sizes of the value network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomPolicy(FeedForwardPolicy):\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        # Three hidden layers of 128 units each for the policy (pi) and the value function (vf)\n",
    "        super(CustomPolicy, self).__init__(*args, **kwargs,\n",
    "                                           net_arch=[dict(pi=[128, 128, 128], vf=[128, 128, 128])],\n",
    "                                           feature_extraction=\"mlp\")"
   ]
  },
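  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for the size of this architecture, the following sketch counts the weights\n",
    "and biases of fully connected layers with the sizes given in `net_arch`. The helper\n",
    "`count_mlp_params` is hypothetical, and the sketch assumes an 8-dimensional observation and\n",
    "4 discrete actions (as in LunarLander-v2), a linear output layer of size 4 for the policy,\n",
    "and a single linear output for the value function. The exact layers Stable Baselines builds\n",
    "may differ, so treat the numbers as a rough estimate:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_mlp_params(input_dim, hidden_sizes, output_dim):\n",
    "    '''Count the weights and biases of a fully connected network.'''\n",
    "    sizes = [input_dim] + list(hidden_sizes) + [output_dim]\n",
    "    # Each layer has (inputs * outputs) weights plus one bias per output\n",
    "    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))\n",
    "\n",
    "\n",
    "toy_net_arch = dict(pi=[128, 128, 128], vf=[128, 128, 128])\n",
    "obs_dim, n_actions = 8, 4  # assumed LunarLander-v2 dimensions\n",
    "\n",
    "pi_params = count_mlp_params(obs_dim, toy_net_arch['pi'], n_actions)\n",
    "vf_params = count_mlp_params(obs_dim, toy_net_arch['vf'], 1)\n",
    "print('policy network parameters:', pi_params)\n",
    "print('value network parameters:', vf_params)"
   ]
  },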
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the agent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = A2C(CustomPolicy, env, ent_coef=0.1, verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can train the agent as usual:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.learn(total_timesteps=250000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, we can evaluate our agent by looking at the mean reward per episode:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
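    "# Depending on the Stable Baselines version, the second return value is either the\n",
    "# number of steps or the standard deviation of the episode rewards\n",
    "mean_reward, n_steps = evaluate_policy(agent, agent.get_env(),\n",
    "                                       n_eval_episodes=10)"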
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also have a look at how our trained agent performs in the environment:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = env.reset()\n",
    "# Runs until interrupted; the vectorized environment resets itself when an episode ends\n",
    "while True:\n",
    "    action, _states = agent.predict(state)\n",
    "    next_state, reward, done, info = env.step(action)\n",
    "    state = next_state\n",
    "    env.render()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
